# -*-Python-*-
# Transformer with some of the feed-forward layers replaced by
# mixture-of-experts (MoE) layers.
# Expert parameter count:
#   3 layers * 32 experts * (1K*4K + 4K*1K) = 768Mi parameters

# Run as a language model.
utils.run.model_type = "lm"

# Gin import: loads the moe module so that the MoE1D configurable referenced
# and bound below is registered with Gin.
import mesh_tensorflow.transformer.moe

# for Unitransformer models (e.g. language-model)
# Sublayer pattern for the layer stack: self-attention, MoE feed-forward,
# self-attention, dense feed-forward. Only one of the two feed-forward
# positions in the pattern is a mixture-of-experts layer.
transformer.make_layer_stack.layers = [
    @mesh_tensorflow.transformer.transformer_layers.SelfAttention,
    @mesh_tensorflow.transformer.moe.MoE1D,
    @mesh_tensorflow.transformer.transformer_layers.SelfAttention,
    @mesh_tensorflow.transformer.transformer_layers.DenseReluDense,
]
# Gin macro. Presumably the number of times the pattern above is repeated,
# which would give 3 MoE layers in total -- TODO confirm against the binding
# that consumes %num_layers (not in this file).
num_layers = 3
# Per-MoE-layer settings: number of experts and each expert's hidden width.
MoE1D.num_experts = 32
MoE1D.hidden_size = 4096
# Gin macros with no consumer in this file; presumably referenced elsewhere as
# %d_ff (dense feed-forward hidden size) and %num_heads (attention heads) --
# TODO confirm against the base config.
d_ff = 8192
num_heads = 16
